# core/formalization/symbol_manager.py
import hashlib
import json
import os
import re
import uuid
from collections import defaultdict
from typing import Dict, Set, List, Optional, Any, Union

from llm.auxiliary import Auxiliary
from utils.logger import Logger
import core.agent_prompt as AgentPrompt
from utils.json_utils import extract_json
from llm.message import (
    Message,
    MessageContent,
    ROLE_SYSTEM,
    ROLE_USER,
    ROLE_ASSISTANT,
    TYPE_SETTING,
    TYPE_CONTEXT,
    TYPE_CONTENT,
)

class SymbolManager:
    """Cache and generate representations (symbols, formulas, notation) for terms.

    Terms are grouped into *concepts*: one original term plus synonyms. A
    concept holds at most one representation per representation type (e.g.
    ``SymbolManager.SYMBOLIC``). Data is kept in memory per *category* and
    persisted as JSONL under ``<config['cache_dir']>/symbols/``. Terms without a
    cached representation are sent to an LLM obtained from ``auxiliary``.
    """

    SYMBOLIC = "symbolic"
    LOGICAL = "logical"
    MATHEMATICAL = "mathematical"
    DOMAIN_SPECIFIC = "domain_specific"

    def __init__(self, logger: Logger, auxiliary: Auxiliary, config: Optional[Dict] = None):
        """
        Args:
            logger: Project logger; ``info`` and ``log_exception`` are called on it.
            auxiliary: Source of the LLM via ``get_api_generate_model()``.
            config: Settings dict; ``cache_dir`` must be set before any cache
                access. Defaults to a fresh empty dict per instance (a literal
                ``{}`` default would be shared by every instance).
        """
        self.logger = logger
        self.auxiliary = auxiliary
        self.config = config if config is not None else {}
        # category -> {'concepts': {id: info}, 'term_to_concept': {term: id},
        #              'used_symbols': {representation_type: set(symbol)}}
        self.category_cache = {}

    def _empty_used_symbols(self) -> Dict[str, Set[str]]:
        """Return a fresh symbol registry pre-seeded with the built-in types."""
        return {
            self.SYMBOLIC: set(),
            self.LOGICAL: set(),
            self.MATHEMATICAL: set(),
            self.DOMAIN_SPECIFIC: set(),
        }

    @staticmethod
    def _lookup_representation(
        term: str,
        representation_type: str,
        concepts: Dict[str, Dict],
        term_to_concept: Dict[str, str],
    ) -> Optional[Dict]:
        """Build the result entry for ``term``, or return None if it has no such representation."""
        concept_id = term_to_concept.get(term)
        if not concept_id or concept_id not in concepts:
            return None
        concept = concepts[concept_id]
        representations = concept.get('representations', {})
        if representation_type not in representations:
            return None
        return {
            'concept_id': concept_id,
            'representation': representations[representation_type],
            'synonyms': concept.get('synonyms', []),
            'original_terms': concept.get('original_terms', []),
        }

    def get_representations_for_terms(self, term_infos: List[Dict], representation_type: str, category: str) -> Dict[str, Dict]:
        """Return representations for ``term_infos``, generating missing ones via the LLM.

        Args:
            term_infos: Dicts that each carry a ``'term'`` key. Uncached ones are
                forwarded (enriched with concept data) to the LLM prompt.
            representation_type: Representation key, e.g. ``SymbolManager.SYMBOLIC``.
            category: Cache namespace; each category has its own cache file.

        Returns:
            Mapping term -> {'concept_id', 'representation', 'synonyms',
            'original_terms'}. Terms for which nothing is cached and the LLM
            produced nothing usable are omitted.

        Raises:
            ValueError: if ``config['cache_dir']`` is not set.
        """
        if category not in self.category_cache:
            self._load_cache_for_category(category)

        # Work on the cached containers directly so updates land in the cache.
        cache = self.category_cache[category]
        concepts = cache.setdefault('concepts', {})
        term_to_concept = cache.setdefault('term_to_concept', {})
        used_symbols = cache.setdefault('used_symbols', self._empty_used_symbols())

        result_representations = {}
        need_generated_terms = []
        for term_info in term_infos:
            term = term_info['term']
            entry = self._lookup_representation(term, representation_type, concepts, term_to_concept)
            if entry is None:
                need_generated_terms.append(term_info)
            else:
                result_representations[term] = entry

        if not need_generated_terms:
            return result_representations

        new_results = self._query_llm_for_representations(
            need_generated_terms,
            representation_type,
            concepts,
            term_to_concept,
            used_symbols,
        )
        if not new_results:
            return result_representations

        for result in new_results:
            self._process_and_store_result(
                result,
                representation_type,
                concepts,
                term_to_concept,
                used_symbols,
            )

        # Re-resolve the previously missing terms against the updated tables.
        for term_info in need_generated_terms:
            term = term_info['term']
            entry = self._lookup_representation(term, representation_type, concepts, term_to_concept)
            if entry is not None:
                result_representations[term] = entry

        self._save_cache_for_category(category)
        return result_representations

    def _process_and_store_result(
        self,
        result: Dict,
        representation_type: str,
        concepts: Dict[str, Dict],
        term_to_concept: Dict[str, str],
        used_symbols: Dict[str, Set[str]]
    ):
        """Merge one LLM result into the concept tables (all mutated in place).

        The result is attached to the first existing concept that owns its term
        or one of its synonyms. Otherwise a new concept is created. Results that
        are not dicts, or lack a non-empty string ``'term'``, are ignored.
        """
        if not isinstance(result, dict):
            return
        term = result.get('term')
        if not isinstance(term, str) or not term:
            return

        raw_synonyms = result.get('synonyms') or []
        if isinstance(raw_synonyms, str):
            raw_synonyms = [raw_synonyms]
        elif not isinstance(raw_synonyms, (list, tuple)):
            raw_synonyms = []
        # Keep order, drop blanks/non-strings/duplicates and the term itself.
        synonyms = list(dict.fromkeys(
            s for s in raw_synonyms if isinstance(s, str) and s and s != term
        ))
        all_terms = [term] + synonyms

        representation = {
            'symbol': result.get('symbol'),
            'explanation': result.get('explanation', ''),
            'formula': result.get('formula', ''),
            'notation': result.get('notation', ''),
        }

        # Only reuse a mapping whose concept actually exists (guards stale ids).
        concept_id = next(
            (term_to_concept[t] for t in all_terms if term_to_concept.get(t) in concepts),
            None,
        )

        if concept_id is not None:
            concept = concepts[concept_id]
            concept.setdefault('representations', {})[representation_type] = representation
            known_terms = set(concept.get('original_terms', [])) | set(concept.get('synonyms', []))
            concept.setdefault('synonyms', []).extend(t for t in all_terms if t not in known_terms)
        else:
            concept_id = str(uuid.uuid4())
            concepts[concept_id] = {
                'original_terms': [term],
                'synonyms': list(synonyms),
                'representations': {representation_type: representation},
                'type': result.get('type', 'Unknown'),
                'context': result.get('context', ''),
            }

        for t in all_terms:
            term_to_concept[t] = concept_id

        symbol = representation['symbol']
        if isinstance(symbol, str) and symbol:
            # setdefault: representation types beyond the built-in four are tracked too.
            used_symbols.setdefault(representation_type, set()).add(symbol)

    def _query_llm_for_representations(
        self,
        term_infos: List[Dict],
        representation_type: str,
        concepts: Dict[str, Dict],
        term_to_concept: Dict[str, str],
        used_symbols: Dict[str, Set[str]]
    ) -> List[Dict]:
        """Ask the LLM for representations of ``term_infos``.

        Each term info is copied and enriched with any known concept data
        (``synonyms``, ``existing_representations``, ``concept_id``). Failures
        are logged and retried up to three times; afterwards ``[]`` is returned.
        """
        enriched_term_infos = []
        for term_info in term_infos:
            enriched_info = term_info.copy()
            concept_id = term_to_concept.get(term_info['term'])
            if concept_id and concept_id in concepts:
                concept = concepts[concept_id]
                known_terms = concept.get('synonyms', []) + concept.get('original_terms', [])
                enriched_info.update({
                    # dict.fromkeys de-duplicates with a deterministic order (set() does not).
                    'synonyms': list(dict.fromkeys(known_terms)),
                    'existing_representations': concept.get('representations', {}),
                    'concept_id': concept_id,
                })
            else:
                enriched_info.update({
                    'synonyms': [],
                    'existing_representations': {},
                    'concept_id': None,
                })
            enriched_term_infos.append(enriched_info)

        max_retries = 3
        for _ in range(max_retries):
            try:
                return self._get_symbol(representation_type, used_symbols, enriched_term_infos, concepts)
            except Exception as e:
                # LLM or parse failure: log and try again.
                self.logger.log_exception(e)
        return []

    def _get_symbol(self, representation_type, used_symbols, enriched_term_infos, concepts):
        """Send one representation prompt to the LLM and return the parsed result list.

        Raises:
            ValueError: if the response does not parse to a JSON list or object,
                so the caller's retry loop gets another attempt.
        """
        prompt = AgentPrompt.search_representation_prompt(
            representation_type.lower(),
            # .get: an unknown type must not raise KeyError (which would exhaust the retries).
            used_symbols.get(representation_type, set()),
            enriched_term_infos,
            self._get_existing_concepts(concepts)
        )

        messages = [Message(ROLE_USER, [MessageContent(TYPE_CONTENT, prompt)])]
        response = self.auxiliary.get_api_generate_model().generate(messages)
        self.logger.info(f"Query {representation_type} representation Response:\n{response}")
        data = extract_json(response)
        # NOTE(review): assumes extract_json returns the parsed JSON value (or a
        # falsy value on failure) — confirm against utils.json_utils.
        if isinstance(data, dict):
            # A single object is treated as a one-element result list.
            data = [data]
        if not isinstance(data, list):
            raise ValueError(f"Expected a JSON list of representations, got {type(data).__name__}")
        return data

    def _get_existing_concepts(self, concepts: Dict[str, Dict]) -> Dict:
        """Summarise concepts as {id: {'terms', 'type', 'representations' (type names)}} for the prompt."""
        concept_summary = {}
        for concept_id, concept in concepts.items():
            concept_summary[concept_id] = {
                'terms': concept.get('original_terms', []) + concept.get('synonyms', []),
                'type': concept.get('type', ''),
                'representations': list(concept.get('representations', {}).keys())
            }
        return concept_summary

    def _get_cache_filepath(self, category: str):
        """Return ``<cache_dir>/symbols/<md5(category)[:8]>.jsonl``, creating the directory.

        The filename keeps only 8 hex characters of the MD5 digest, so two
        distinct categories could in principle map to the same file.

        Raises:
            ValueError: if ``config['cache_dir']`` is not set.
        """
        cache_dir = self.config.get('cache_dir', None)
        if not cache_dir:
            raise ValueError("Unknown cache dir")

        symbols_dir = os.path.join(cache_dir, 'symbols')
        os.makedirs(symbols_dir, exist_ok=True)

        category_hash = hashlib.md5(category.encode('utf-8')).hexdigest()[:8]
        return os.path.join(symbols_dir, f"{category_hash}.jsonl")

    def _load_cache_for_category(self, category: str):
        """Populate ``self.category_cache[category]`` from its JSONL file (once per category).

        A corrupt record is logged and skipped so the remaining records still
        load. Otherwise the next save would drop them from disk.

        Raises:
            ValueError: if ``config['cache_dir']`` is not set.
        """
        if category in self.category_cache:
            return

        # Resolve the path before marking the category as loaded. A missing
        # cache_dir then fails on every call, not just the first.
        cache_file = self._get_cache_filepath(category)

        cache = {
            'concepts': {},
            'term_to_concept': {},
            'used_symbols': self._empty_used_symbols(),
        }
        self.category_cache[category] = cache

        if not os.path.exists(cache_file):
            return

        try:
            with open(cache_file, 'r', encoding='utf-8') as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        self._load_cache_record(cache, json.loads(line))
                    except (ValueError, KeyError, TypeError, AttributeError) as e:
                        self.logger.log_exception(e)
        except Exception as e:
            self.logger.log_exception(e)

    @staticmethod
    def _load_cache_record(cache: Dict, data: Dict):
        """Index one persisted ``{'concept_id', 'concept_info'}`` record into ``cache``."""
        concept_id = data['concept_id']
        concept_info = data['concept_info']
        if not isinstance(concept_info, dict):
            raise TypeError(f"concept_info must be a dict, got {type(concept_info).__name__}")
        # Compute everything that can fail before mutating the cache.
        terms = list(concept_info.get('original_terms', [])) + list(concept_info.get('synonyms', []))
        representations = concept_info.get('representations', {})

        cache['concepts'][concept_id] = concept_info
        for term in terms:
            cache['term_to_concept'][term] = concept_id
        for rep_type, rep_info in representations.items():
            symbol = rep_info.get('symbol') if isinstance(rep_info, dict) else None
            if isinstance(symbol, str) and symbol:
                cache['used_symbols'].setdefault(rep_type, set()).add(symbol)

    def _save_cache_for_category(self, category: str):
        """Persist the category's concepts as JSONL, one ``{'concept_id', 'concept_info'}`` per line.

        Writes to a temporary file and atomically swaps it in with
        ``os.replace``, so a failure mid-write leaves the previous file intact.
        I/O and serialisation errors are logged, not raised.

        Raises:
            ValueError: if ``config['cache_dir']`` is not set.
        """
        cache_file = self._get_cache_filepath(category)
        tmp_file = cache_file + '.tmp'

        try:
            with open(tmp_file, 'w', encoding='utf-8') as f:
                for concept_id, concept_info in self.category_cache[category]['concepts'].items():
                    record = {
                        'concept_id': concept_id,
                        'concept_info': concept_info
                    }
                    f.write(json.dumps(record, ensure_ascii=False) + '\n')
            os.replace(tmp_file, cache_file)
        except Exception as e:
            self.logger.log_exception(e)
            try:
                os.remove(tmp_file)
            except OSError:
                pass  # best-effort cleanup; the real error is already logged

    def replace_terms_in_text(self, text: str, representation_type: str, category: str) -> str:
        """Replace every known term in ``text`` with its ``representation_type`` symbol.

        Matching is a single left-to-right pass that prefers the longest term
        at each position. A short term therefore never breaks a longer one
        ("set" inside "subset"), and inserted symbols are never re-replaced.
        Empty or non-string terms and symbols are ignored.
        """
        if category not in self.category_cache:
            self._load_cache_for_category(category)

        term_to_concept = self.category_cache[category]['term_to_concept']
        concepts = self.category_cache[category]['concepts']

        replacements = {}
        for term, concept_id in term_to_concept.items():
            if not isinstance(term, str) or not term:
                continue  # an empty term would match between every character
            concept = concepts.get(concept_id)
            if not concept:
                continue
            representation = concept.get('representations', {}).get(representation_type)
            if not isinstance(representation, dict):
                continue
            symbol = representation.get('symbol', '')
            if isinstance(symbol, str) and symbol:
                replacements[term] = symbol

        if not replacements or not text:
            return text

        # Python alternation tries alternatives in order, so longest-first gives longest match.
        pattern = re.compile('|'.join(
            re.escape(t) for t in sorted(replacements, key=len, reverse=True)
        ))
        return pattern.sub(lambda m: replacements[m.group(0)], text)

    def get_concept_info(self, term: str, category: str) -> Optional[Dict]:
        """Return a shallow copy of the concept owning ``term``, or None if unknown.

        Nested dicts/lists in the copy are still shared with the cache.
        """
        if category not in self.category_cache:
            self._load_cache_for_category(category)

        term_to_concept = self.category_cache[category]['term_to_concept']
        concepts = self.category_cache[category]['concepts']

        concept_id = term_to_concept.get(term)
        if concept_id and concept_id in concepts:
            return concepts[concept_id].copy()
        return None

    def get_all_concepts(self, category: str) -> Dict:
        """Return a shallow copy of the ``{concept_id: concept_info}`` map for ``category``."""
        if category not in self.category_cache:
            self._load_cache_for_category(category)
        return self.category_cache[category]['concepts'].copy()

    def get_concepts_count(self, category: str) -> int:
        """Return the number of concepts stored for ``category``."""
        if category not in self.category_cache:
            self._load_cache_for_category(category)
        return len(self.category_cache[category]['concepts'])

    def has_term(self, term: str, category: str) -> bool:
        """Return True if ``term`` (original or synonym) maps to a concept in ``category``."""
        if category not in self.category_cache:
            self._load_cache_for_category(category)
        return term in self.category_cache[category]['term_to_concept']